# Definition for a binary tree node.
class TreeNode:
    """A binary tree node: a value plus optional left and right children."""

    def __init__(self, val=0, left=None, right=None):
        # Store the payload and both child links exactly as given.
        self.val, self.left, self.right = val, left, right

class Solution:
    def pruneTree(self, root):
        """Remove every subtree that does not contain a node with value 1.

        A node survives iff its own value is 1 or some descendant's value is
        1. Pruning happens in place by setting ``left``/``right`` links to
        ``None``; surviving nodes keep their original positions (a 0-valued
        node that still has a 1 below it is kept, not replaced by its child).

        Args:
            root: Root node of the tree (an object with ``val``, ``left`` and
                ``right`` attributes), or ``None`` for an empty tree.

        Returns:
            The root of the pruned tree, or ``None`` if no node has value 1.
        """
        if root is None:
            return None

        def is_prunable_leaf(node):
            # A childless node whose value is not 1 cannot contain a 1.
            return (
                node is not None
                and node.val != 1
                and node.left is None
                and node.right is None
            )

        # Collect nodes in pre-order: each node is appended before any of its
        # descendants, so walking the list in reverse visits children before
        # parents. This gives a post-order pass without recursion, so deep
        # (degenerate) trees cannot hit Python's recursion limit.
        order = []
        stack = [root]
        while stack:
            node = stack.pop()
            order.append(node)
            if node.left is not None:
                stack.append(node.left)
            if node.right is not None:
                stack.append(node.right)

        for node in reversed(order):
            # The children were already pruned, so a child that is now a
            # non-1 leaf has no 1 anywhere in its original subtree.
            if is_prunable_leaf(node.left):
                node.left = None
            if is_prunable_leaf(node.right):
                node.right = None

        # The root itself may be left as a lone non-1 leaf: prune it too.
        return None if is_prunable_leaf(root) else root
    

if __name__ == "__main__":
    # Demo: a tree containing only zeros is pruned away entirely.
    sample = TreeNode(
        0,
        None,
        TreeNode(0, TreeNode(0), TreeNode(0)),
    )
    pruned = Solution().pruneTree(sample)
    print(pruned)  # expected: None